{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gentrl\n",
    "import torch\n",
    "torch.cuda.set_device(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = gentrl.RNNEncoder(latent_size=50)\n",
    "dec = gentrl.DilConvDecoder(latent_input_size=50)\n",
    "model = gentrl.GENTRL(enc, dec, 50 * [('c', 20)], [('c', 20)], beta=0.001)\n",
    "model.cuda();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load('saved_gentrl/')\n",
    "model.cuda();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from moses.metrics import mol_passes_filters, QED, SA, logP\n",
    "from moses.metrics.utils import get_n_rings, get_mol\n",
    "\n",
    "from moses.utils import disable_rdkit_log\n",
    "disable_rdkit_log()\n",
    "\n",
    "def get_num_rings_6(mol):\n",
    "    r = mol.GetRingInfo()\n",
    "    return len([x for x in r.AtomRings() if len(x) > 6])\n",
    "\n",
    "\n",
    "def penalized_logP(mol_or_smiles, masked=False, default=-5):\n",
    "    mol = get_mol(mol_or_smiles)\n",
    "    if mol is None:\n",
    "        return default\n",
    "    reward = logP(mol) - SA(mol) - get_num_rings_6(mol)\n",
    "    if masked and not mol_passes_filters(mol):\n",
    "        return default\n",
    "    return reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.train_as_rl(penalized_logP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "! mkdir -p saved_gentrl_after_rl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save('./saved_gentrl_after_rl/')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
